{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "46198a85",
   "metadata": {},
   "source": [
    "# MultiQueryRetriever\n",
    "\n",
    "거리 기반 벡터 데이터베이스 검색은 고차원 공간에서의 쿼리 임베딩(표현)과 '거리'를 기준으로 유사한 임베딩을 가진 문서를 찾는 방식입니다. 하지만 쿼리의 **세부적인 차이나 임베딩이 데이터의 의미를 제대로 포착하지 못할 경우, 검색 결과가 달라질 수** 있습니다. 또한, 이를 수동으로 조정하는 프롬프트 엔지니어링이나 튜닝 작업은 번거로울 수 있습니다.\n",
    "\n",
    "이런 문제를 해결하기 위해, `MultiQueryRetriever` 는 주어진 사용자 입력 쿼리에 대해 다양한 관점에서 여러 쿼리를 자동으로 생성하는 LLM(Language Learning Model)을 활용해 프롬프트 튜닝 과정을 자동화합니다.\n",
    "\n",
    "이 방식은 각각의 쿼리에 대해 관련 문서 집합을 검색하고, 모든 쿼리를 아우르는 고유한 문서들의 합집합을 추출해, 잠재적으로 관련된 더 큰 문서 집합을 얻을 수 있게 해줍니다. \n",
    "\n",
    "여러 관점에서 동일한 질문을 생성함으로써, `MultiQueryRetriever` 는 거리 기반 검색의 제한을 일정 부분 극복하고, 더욱 풍부한 검색 결과를 제공할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ee62e07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# API 키를 환경변수로 관리하기 위한 설정 파일\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# API 키 정보 로드\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce16f0bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LangSmith 추적을 설정합니다. https://smith.langchain.com\n",
    "# !pip install langchain-teddynote\n",
    "from langchain_teddynote import logging\n",
    "\n",
    "# 프로젝트 이름을 입력합니다.\n",
    "logging.langsmith(\"CH10-Retriever\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dae75cb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 샘플 벡터DB 구축\n",
    "from langchain_community.document_loaders import WebBaseLoader\n",
    "from langchain.vectorstores import FAISS\n",
    "from langchain_openai import OpenAIEmbeddings\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "\n",
    "# 블로그 포스트 로드\n",
    "loader = WebBaseLoader(\n",
    "    \"https://teddylee777.github.io/openai/openai-assistant-tutorial/\", encoding=\"utf-8\"\n",
    ")\n",
    "\n",
    "# 문서 분할\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)\n",
    "docs = loader.load_and_split(text_splitter)\n",
    "\n",
    "# 임베딩 정의\n",
    "openai_embedding = OpenAIEmbeddings()\n",
    "\n",
    "# 벡터DB 생성\n",
    "db = FAISS.from_documents(docs, openai_embedding)\n",
    "\n",
    "# retriever 생성\n",
    "retriever = db.as_retriever()\n",
    "\n",
    "# 문서 검색\n",
    "query = \"OpenAI Assistant API의 Functions 사용법에 대해 알려주세요.\"\n",
    "relevant_docs = retriever.invoke(query)\n",
    "\n",
    "# 검색된 문서의 개수 출력\n",
    "len(relevant_docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43d88fd7",
   "metadata": {},
   "source": [
    "검색된 결과 중 1개 문서의 내용을 출력합니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8ea9548",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1번 문서를 출력합니다.\n",
    "print(relevant_docs[1].page_content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8f3224e",
   "metadata": {},
   "source": [
    "## 사용방법\n",
    "\n",
    "`MultiQueryRetriever` 에 사용할 LLM을 지정하고 질의 생성에 사용하면, retriever가 나머지 작업을 처리합니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1637815b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers.multi_query import MultiQueryRetriever\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "\n",
    "# ChatOpenAI 언어 모델을 초기화합니다. temperature는 0으로 설정합니다.\n",
    "llm = ChatOpenAI(temperature=0, model=\"gpt-4o-mini\")\n",
    "\n",
    "multiquery_retriever = MultiQueryRetriever.from_llm(  # MultiQueryRetriever를 언어 모델을 사용하여 초기화합니다.\n",
    "    # 벡터 데이터베이스의 retriever와 언어 모델을 전달합니다.\n",
    "    retriever=db.as_retriever(),\n",
    "    llm=llm,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbd0b647",
   "metadata": {},
   "source": [
    "아래는 다중 쿼리를 생성하는 중간 과정을 디버깅하기 위하여 실행하는 코드입니다.\n",
    "\n",
    "먼저 `\"langchain.retrievers.multi_query\"` 로거를 가져옵니다. \n",
    "\n",
    "이는 `logging.getLogger()` 함수를 사용하여 수행됩니다. 그 다음, 이 로거의 로그 레벨을 `INFO`로 설정하여, `INFO` 레벨 이상의 로그 메시지만 출력되도록 할 수 있습니다. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "901d1749",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 쿼리에 대한 로깅 설정\n",
    "import logging\n",
    "\n",
    "logging.basicConfig()\n",
    "logging.getLogger(\"langchain.retrievers.multi_query\").setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfda8ebb",
   "metadata": {},
   "source": [
    "이 코드는 `retriever_from_llm` 객체의 `invoke` 메서드를 사용하여 주어진 `question`과 관련된 문서를 검색합니다. \n",
    "\n",
    "검색된 문서들은 `unique_docs`라는 변수에 저장되며, 이 변수의 길이를 확인함으로써 검색된 관련 문서의 총 개수를 알 수 있습니다. 이 과정을 통해 사용자의 질문에 대한 관련 정보를 효과적으로 찾아내고 그 양을 파악할 수 있습니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e305f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 질문을 정의합니다.\n",
    "question = \"OpenAI Assistant API의 Functions 사용법에 대해 알려주세요.\"\n",
    "# 문서 검색\n",
    "relevant_docs = multiquery_retriever.invoke(question)\n",
    "\n",
    "# 검색된 고유한 문서의 개수를 반환합니다.\n",
    "print(\n",
    "    f\"===============\\n검색된 문서 개수: {len(relevant_docs)}\",\n",
    "    end=\"\\n===============\\n\",\n",
    ")\n",
    "\n",
    "# 검색된 문서의 내용을 출력합니다.\n",
    "print(relevant_docs[0].page_content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81695892",
   "metadata": {},
   "source": [
    "## LCEL Chain 활용하는 방법\n",
    "\n",
    "- 사용자 정의 프롬프트 정의하고, 정의한 프롬프트와 함께 Chain 을 생성합니다.\n",
    "- Chain 은 사용자의 질문을 입력 받으면 (아래의 예제에서는) 5개의 질문을 생성한 뒤 `\"\\n\"` 구분자로 구분하여 생성된 5개 질문을 반환합니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ab98687",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.runnables import RunnablePassthrough\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "\n",
    "# 프롬프트 템플릿을 정의합니다.(5개의 질문을 생성하도록 프롬프트를 작성하였습니다)\n",
    "prompt = PromptTemplate.from_template(\n",
    "    \"\"\"You are an AI language model assistant. \n",
    "Your task is to generate five different versions of the given user question to retrieve relevant documents from a vector database. \n",
    "By generating multiple perspectives on the user question, your goal is to help the user overcome some of the limitations of the distance-based similarity search. \n",
    "Your response should be a list of values separated by new lines, eg: `foo\\nbar\\nbaz\\n`\n",
    "\n",
    "#ORIGINAL QUESTION: \n",
    "{question}\n",
    "\n",
    "#Answer in Korean:\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "# 언어 모델 인스턴스를 생성합니다.\n",
    "llm = ChatOpenAI(temperature=0, model=\"gpt-4o-mini\")\n",
    "\n",
    "# LLMChain을 생성합니다.\n",
    "custom_multiquery_chain = (\n",
    "    {\"question\": RunnablePassthrough()} | prompt | llm | StrOutputParser()\n",
    ")\n",
    "\n",
    "# 질문을 정의합니다.\n",
    "question = \"OpenAI Assistant API의 Functions 사용법에 대해 알려주세요.\"\n",
    "\n",
    "# 체인을 실행하여 생성된 다중 쿼리를 확인합니다.\n",
    "multi_queries = custom_multiquery_chain.invoke(question)\n",
    "# 결과를 확인합니다.(5개 질문 생성)\n",
    "multi_queries"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c6403eb",
   "metadata": {},
   "source": [
    "이전에 생성한 Chain을 `MultiQueryRetriever` 에 전달하여 retrieve 할 수 있습니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f3cac81",
   "metadata": {},
   "outputs": [],
   "source": [
    "multiquery_retriever = MultiQueryRetriever.from_llm(\n",
    "    llm=custom_multiquery_chain, retriever=db.as_retriever()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "086076bb",
   "metadata": {},
   "source": [
    "`MultiQueryRetriever`를 사용하여 문서를 검색하고 결과를 확인합니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6eaffe30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 결과\n",
    "relevant_docs = multiquery_retriever.invoke(question)\n",
    "\n",
    "# 검색된 고유한 문서의 개수를 반환합니다.\n",
    "print(\n",
    "    f\"===============\\n검색된 문서 개수: {len(relevant_docs)}\",\n",
    "    end=\"\\n===============\\n\",\n",
    ")\n",
    "\n",
    "# 검색된 문서의 내용을 출력합니다.\n",
    "print(relevant_docs[0].page_content)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py-test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
